//+------------------------------------------------------------------+
//|                                         Transfer Learning EA.mq5 |
//|                                          Copyright 2023, Omegafx |
//|                 https://www.mql5.com/en/users/omegajoctan/seller |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omegafx"
#property link      "https://www.mql5.com/en/users/omegajoctan/seller"
#property version   "1.00"

#include <ta.mqh> //similar to ta in Python --> https://www.mql5.com/en/articles/16931
#include <pandas.mqh> //similar to Pandas in Python --> https://www.mql5.com/en/articles/17030
#include <CNN.mqh> //For loading Convolutional Neural networks in ONNX format --> https://www.mql5.com/en/articles/15259
#include <preprocessing.mqh> //For loading the scaler transformer
#include <Trade\Trade.mqh> //The trading module
#include <Trade\PositionInfo.mqh> //Position handling module

CCNNClassifier cnn; // CNN classifier: loaded from an ONNX file in the common folder in OnInit, queried in OnTick
RobustScaler scaler; // Scaler loaded from ONNX in OnInit; transforms the feature window before prediction
CTrade m_trade; // Order execution helper; magic number, deviation and filling mode are configured in OnInit
CPositionInfo m_position; // Position selector shared by PosExists, ClosePos and CloseTradeAfterTime

// Symbol embedded in the CNN model filename ("basesymbol=..."); presumably the symbol the base model was trained on — confirm
input string base_symbol = "EURUSD";
// Symbol the chart is switched to in OnInit (outside the tester); also part of the model and scaler filenames
input string symbol_ = "USDJPY";
// Timeframe the chart is switched to in OnInit (outside the tester); also part of the model and scaler filenames
input ENUM_TIMEFRAMES timeframe = PERIOD_H4;
// Number of most recent feature rows passed to the CNN for each prediction
input uint window_ = 10;
// Holding period in bars of the chart period: positions are closed after lookahead * PeriodSeconds(Period()) seconds
input uint lookahead = 1;
// Magic number stamped on this EA's orders and used to recognise its positions
input uint magic_number = 28042025;
// Maximum allowed price deviation (in points) for order execution and position closing
input uint slippage = 100;

// Class labels passed to cnn.predict(); OnTick buys on class 1 and sells on class 0
long classes_in_y_[] = {0, 1};
// Bar count seen by the previous isNewBar() call (-1 = not called yet)
int OldNumBars = -1;
//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- Validate inputs that OnTick depends on:
//    window_ == 0 leaves no rows to slice for the CNN input, and
//    lookahead == 0 makes the time-based exit threshold zero, so positions
//    would be closed right after they are opened.
   if (window_ == 0 || lookahead == 0)
     {
      printf("%s Invalid inputs: window_ = %u and lookahead = %u must both be greater than zero",__FUNCTION__,window_,lookahead);
      return INIT_PARAMETERS_INCORRECT;
     }

//--- Outside the Strategy Tester, switch this chart to the symbol/timeframe encoded
//    in the model and scaler filenames. In the tester the chart is not switched, so
//    only warn on a mismatch: features are always computed from Symbol()/Period().
   if (!MQLInfoInteger(MQL_TESTER))
     {
      ResetLastError();
      if (!ChartSetSymbolPeriod(0, symbol_, timeframe))
        {
         printf("%s Failed to set symbol_ = %s and timeframe = %s, Error = %d",__FUNCTION__,symbol_,EnumToString(timeframe), GetLastError());
         return INIT_FAILED;
        }
     }
   else if (Symbol() != symbol_ || Period() != timeframe)
     {
      printf("%s Warning: tester symbol/timeframe (%s, %s) differ from symbol_ = %s and timeframe = %s; features will be computed from the tester chart",
             __FUNCTION__,Symbol(),EnumToString(Period()),symbol_,EnumToString(timeframe));
     }

//--- Load the CNN model (ONNX) from the terminal's common folder

   string filename = StringFormat("basesymbol=%s.symbol=%s.model.%s.onnx",base_symbol, symbol_, EnumToString(timeframe));
   ResetLastError();
   if (!cnn.Init(filename, ONNX_COMMON_FOLDER))
     {
      printf("%s failed to load a CNN model in ONNX format from the common folder '%s', Error = %d",__FUNCTION__,filename,GetLastError());
      return INIT_FAILED;
     }

//--- Load the scaler (ONNX); its filename is keyed by symbol_ and timeframe only

   filename = StringFormat("%s.%s.scaler.onnx", symbol_, EnumToString(timeframe));
   ResetLastError();
   if (!scaler.Init(filename, ONNX_COMMON_FOLDER))
     {
      printf("%s failed to load a scaler in ONNX format from the common folder '%s', Error = %d",__FUNCTION__,filename,GetLastError());
      return INIT_FAILED;
     }

//--- Configure order execution for this EA

   m_trade.SetExpertMagicNumber(magic_number);
   m_trade.SetDeviationInPoints(slippage);
   m_trade.SetMarginMode();
   m_trade.SetTypeFillingBySymbol(Symbol());

   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- Intentionally empty: no explicit cleanup is performed on deinitialization,
//    and the deinitialization `reason` is not inspected.
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- Time-based exit runs on every tick, before the new-bar gate and before the
//    entry logic. This does three things:
//    - a position is closed as soon as it has been open for lookahead bars' worth of seconds;
//    - an expired position is already gone when PosExists() is checked below;
//    - exits still happen when the prediction step returns early.
   CloseTradeAfterTime((Timeframe2Minutes(Period())*lookahead)*60);

//--- Features and predictions are computed only once per bar
   if (!isNewBar())
      return;

   CDataFrame x_df = getStationaryVars();

//--- Check if the number of rows received after indicator calculation is >= window size

   if ((uint)x_df.shape()[0]<window_)
     {
      printf("%s Fatal, Data received is less than the desired window=%u. Check your indicators or increase the number of bars in the function getSationaryVars()",__FUNCTION__,window_);
      DebugBreak();
      return;
     }

   ulong rows = (ulong)x_df.shape()[0];
   ulong cols = (ulong)x_df.shape()[1];

//--- Slice rows rows-window_ .. rows-1 over all columns: the last window_ rows,
//    assuming iloc's end indices are inclusive (TODO confirm against pandas.mqh)
   matrix x = x_df.iloc((rows-window_), rows-1, 0, cols-1).m_values;

   matrix x_scaled = scaler.transform(x); //Transform the data, very important

   long signal = cnn.predict(x_scaled, classes_in_y_).cls; //Predicted class

//--- Trading functionality

   MqlTick ticks;
   if (!SymbolInfoTick(Symbol(), ticks))
     {
      printf("Failed to obtain ticks information, Error = %d",GetLastError());
      return;
     }

   double volume_ = SymbolInfoDouble(Symbol(), SYMBOL_VOLUME_MIN);

   if (signal == 1) //Class 1: buy, but only when this EA has no open position on the symbol
     {
      if (!PosExists(POSITION_TYPE_BUY) && !PosExists(POSITION_TYPE_SELL))
         m_trade.Buy(volume_, Symbol(), ticks.ask,0,0);
     }

   if (signal == 0) //Class 0: sell, but only when this EA has no open position on the symbol
     {
      if (!PosExists(POSITION_TYPE_SELL) && !PosExists(POSITION_TYPE_BUY))
         m_trade.Sell(volume_, Symbol(), ticks.bid,0,0);
     }
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
CDataFrame getStationaryVars(uint start = 1, uint bars = 50)
  {
//--- Builds the feature dataframe consumed by the scaler + CNN in OnTick.
//    start : first bar index to copy, counted back from the current bar (1 skips the forming bar 0)
//    bars  : number of bars to copy; indicator warm-up rows are removed by dropna()
//    Returns the dataframe after dropna(), or an empty dataframe if rates could not be copied.

   CDataFrame df; //Dataframe object

   vector open, high, low, close;
   ResetLastError();
   if (!open.CopyRates(Symbol(), Period(), COPY_RATES_OPEN, start, bars) ||
       !high.CopyRates(Symbol(), Period(), COPY_RATES_HIGH, start, bars) ||
       !low.CopyRates(Symbol(), Period(), COPY_RATES_LOW, start, bars) ||
       !close.CopyRates(Symbol(), Period(), COPY_RATES_CLOSE, start, bars))
     {
      printf("%s failed to copy %u bars of rates for %s, Error = %d",__FUNCTION__,bars,Symbol(),GetLastError());
      return df; // empty dataframe; presumably shape()[0] == 0, which OnTick's row-count check rejects — confirm with pandas.mqh
     }

//--- Stationary price features: the diff_* columns must hold the differenced
//    series, not the raw prices (raw OHLC levels are non-stationary)
   vector pct_change = df.pct_change(close);
   vector diff_open = df.diff(open);
   vector diff_high = df.diff(high);
   vector diff_low = df.diff(low);
   vector diff_close = df.diff(close);

   df.insert("pct_change", pct_change);
   df.insert("diff_open", diff_open);
   df.insert("diff_high", diff_high);
   df.insert("diff_low", diff_low);
   df.insert("diff_close", diff_close);

   // Relative Strength Index (RSI)
   vector rsi = CMomentumIndicators::RSIIndicator(close);
   df.insert("rsi", rsi);

   // Stochastic Oscillator (Stoch)
   vector stock_k = CMomentumIndicators::StochasticOscillator(close,high,low).stoch;
   df.insert("stock_k", stock_k);

   // Moving Average Convergence Divergence (MACD)
   vector macd = COscillatorIndicators::MACDIndicator(close).main;
   df.insert("macd", macd);

   // Commodity Channel Index (CCI)
   vector cci = COscillatorIndicators::CCIIndicator(high,low,close);
   df.insert("cci", cci);

   // Rate of Change (ROC)
   vector roc = CMomentumIndicators::ROCIndicator(close);
   df.insert("roc", roc);

   // Ultimate Oscillator (UO)
   vector uo = CMomentumIndicators::UltimateOscillator(high,low,close);
   df.insert("uo", uo);

   // Williams %R
   vector williams_r = CMomentumIndicators::WilliamsR(high,low,close);
   df.insert("williams_r", williams_r);

   // Average True Range (ATR)
   vector atr = COscillatorIndicators::ATRIndicator(high,low,close);
   df.insert("atr", atr);

   // Awesome Oscillator (AO)
   vector ao = CMomentumIndicators::AwesomeOscillator(high,low);
   df.insert("ao", ao);

   // Average Directional Index (ADX)
   vector adx = COscillatorIndicators::ADXIndicator(high,low,close).adx;
   df.insert("adx", adx);

   // True Strength Index (TSI)
   vector tsi = CMomentumIndicators::TSIIndicator(close);
   df.insert("tsi", tsi);

   if (MQLInfoInteger(MQL_DEBUG))
      df.head();

   df = df.dropna(); //Drop rows containing NaN values (presumably leading diff/indicator warm-up rows)

   return df; //All remaining rows; OnTick slices the last window_ rows itself
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
bool PosExists(ENUM_POSITION_TYPE type)
  {
//--- Reports whether this EA (matching magic_number) holds a position of the
//    given type on the chart symbol. Scans from the highest index down and stops
//    at the first match, leaving m_position selected on that position.
   int idx = PositionsTotal();
   while (--idx >= 0)
     {
      if (!m_position.SelectByIndex(idx))
         continue;

      bool is_ours = m_position.Symbol()==Symbol() && m_position.Magic() == magic_number && m_position.PositionType()==type;
      if (is_ours)
         return (true);
     }

   return (false);
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
bool ClosePos(ENUM_POSITION_TYPE type)
  {
//--- Closes one of this EA's positions of the given type on the chart symbol.
//    Scans from the highest index down; returns true as soon as one close succeeds,
//    false when nothing matched or every close attempt failed.
   for (int idx = PositionsTotal()-1; idx >= 0; --idx)
     {
      if (!m_position.SelectByIndex(idx))
         continue;

      // Skip positions that belong to another symbol, another EA, or the other direction
      if (m_position.Symbol() != Symbol() || m_position.Magic() != magic_number || m_position.PositionType() != type)
         continue;

      if (m_trade.PositionClose(m_position.Ticket()))
         return (true);
     }

   return (false);
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
void CloseTradeAfterTime(int period_seconds)
{
//--- Closes every position of this EA (magic_number) on the chart symbol that has
//    been open for at least period_seconds, measured as TimeCurrent() minus the
//    position's open time. Filtering by symbol matches PosExists/ClosePos, so
//    positions on other symbols that share the magic number are left alone.
   for (int i = PositionsTotal() - 1; i >= 0; i--)
     {
      if (!m_position.SelectByIndex(i))
         continue;

      if (m_position.Symbol() != Symbol() || m_position.Magic() != magic_number)
         continue;

      if (TimeCurrent() - m_position.Time() < period_seconds)
         continue;

      // Report failures instead of silently ignoring them; the next call retries
      if (!m_trade.PositionClose(m_position.Ticket(), slippage))
         printf("%s failed to close position #%I64u, retcode = %u",__FUNCTION__,m_position.Ticket(),m_trade.ResultRetcode());
     }
}
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
int Timeframe2Minutes(ENUM_TIMEFRAMES tf)
{
//--- Length of one bar of `tf`, in whole minutes (truncated).
   const int seconds_in_tf = PeriodSeconds(tf);
   return seconds_in_tf / 60;
}
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
bool isNewBar()
  {
//--- Returns true once per bar: the first time it sees a new open time for bar 0
//    of the chart symbol/period. Comparing open times (rather than the Bars()
//    count) avoids two failure modes:
//    - false triggers when history is back-filled;
//    - missed bars once the terminal's "max bars in chart" limit keeps the count constant.
   static datetime last_bar_time = 0; // open time of the last bar reported as new (0 = none yet)

   const datetime current_bar_time = iTime(Symbol(), Period(), 0);
   if (current_bar_time == 0) // series not available yet / request failed
      return false;

   if (current_bar_time == last_bar_time)
      return false;

   last_bar_time = current_bar_time;
   return true;
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+


